{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Pnn4rDWGqDZL"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "l534d35Gp68G"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3TI3Q3XBesaS"
      },
      "source": [
        "# Training checkpoints"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yw_a0iGucY8z"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/checkpoint\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/checkpoint.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/checkpoint.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/checkpoint.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LeDp7dovcbus"
      },
      "source": [
        "The phrase \"Saving a TensorFlow model\" typically means one of two things:\n",
        "\n",
        "  1. Checkpoints, OR \n",
        "  2. SavedModel.\n",
        "\n",
        "Checkpoints capture the exact value of all parameters (`tf.Variable` objects) used by a model. Checkpoints do not contain any description of the computation defined by the model and thus are typically only useful when source code that will use the saved parameter values is available.\n",
        "\n",
        "The SavedModel format, on the other hand, includes a serialized description of the computation defined by the model in addition to the parameter values (checkpoint). Models in this format are independent of the source code that created the model. They are thus suitable for deployment via TensorFlow Serving, TensorFlow Lite, TensorFlow.js, or programs in other programming languages (the C, C++, Java, Go, Rust, C#, etc. TensorFlow APIs).\n",
        "\n",
        "This guide covers APIs for writing and reading checkpoints."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U0nm8k-6xfh2"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VEvpMYAKsC4z"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OEQCseyeC4Ev"
      },
      "outputs": [],
      "source": [
        "class Net(tf.keras.Model):\n",
        "  \"\"\"A simple linear model.\"\"\"\n",
        "\n",
        "  def __init__(self):\n",
        "    super(Net, self).__init__()\n",
        "    self.l1 = tf.keras.layers.Dense(5)\n",
        "\n",
        "  def call(self, x):\n",
        "    return self.l1(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "utqeoDADC5ZR"
      },
      "outputs": [],
      "source": [
        "net = Net()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5vsq3-pffo1I"
      },
      "source": [
        "## Saving from `tf.keras` training APIs\n",
        "\n",
        "See the [`tf.keras` guide on saving and\n",
        "restoring](https://www.tensorflow.org/guide/keras/save_and_serialize).\n",
        "\n",
        "`tf.keras.Model.save_weights` saves a TensorFlow checkpoint. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SuhmrYPEl4D_"
      },
      "outputs": [],
      "source": [
        "net.save_weights('easy_checkpoint')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XseWX5jDg4lQ"
      },
      "source": [
        "## Writing checkpoints\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1jpZPz76ZP3K"
      },
      "source": [
        "The persistent state of a TensorFlow model is stored in `tf.Variable` objects. These can be constructed directly, but are often created through high-level APIs like `tf.keras.layers` or `tf.keras.Model`.\n",
        "\n",
        "The easiest way to manage variables is by attaching them to Python objects, then referencing those objects. \n",
        "\n",
        "Subclasses of `tf.train.Checkpoint`, `tf.keras.layers.Layer`, and `tf.keras.Model` automatically track variables assigned to their attributes. The following example constructs a simple linear model, then writes checkpoints which contain values for all of the model's variables."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x0vFBr_Im73_"
      },
      "source": [
        "You can easily save a model checkpoint with `Model.save_weights`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FHTJ1JzxCi8a"
      },
      "source": [
        "### Manual checkpointing"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6cF9fqYOCrEO"
      },
      "source": [
        "#### Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fNjf9KaLdIRP"
      },
      "source": [
        "To help demonstrate all the features of `tf.train.Checkpoint`, define a toy dataset and optimization step:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tSNyP4IJ9nkU"
      },
      "outputs": [],
      "source": [
        "def toy_dataset():\n",
        "  inputs = tf.range(10.)[:, None]\n",
        "  labels = inputs * 5. + tf.range(5.)[None, :]\n",
        "  return tf.data.Dataset.from_tensor_slices(\n",
        "    dict(x=inputs, y=labels)).repeat().batch(2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ICm1cufh_JH8"
      },
      "outputs": [],
      "source": [
        "def train_step(net, example, optimizer):\n",
        "  \"\"\"Trains `net` on `example` using `optimizer`.\"\"\"\n",
        "  with tf.GradientTape() as tape:\n",
        "    output = net(example['x'])\n",
        "    loss = tf.reduce_mean(tf.abs(output - example['y']))\n",
        "  variables = net.trainable_variables\n",
        "  gradients = tape.gradient(loss, variables)\n",
        "  optimizer.apply_gradients(zip(gradients, variables))\n",
        "  return loss"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vxzGpHRbOVO6"
      },
      "source": [
        "#### Create the checkpoint objects\n",
        "\n",
        "Use a `tf.train.Checkpoint` object to manually create a checkpoint, where the objects you want to checkpoint are set as attributes on the object.\n",
        "\n",
        "A `tf.train.CheckpointManager` can also be helpful for managing multiple checkpoints."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ou5qarOQOWYl"
      },
      "outputs": [],
      "source": [
        "opt = tf.keras.optimizers.Adam(0.1)\n",
        "dataset = toy_dataset()\n",
        "iterator = iter(dataset)\n",
        "ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)\n",
        "manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8ZbYSD4uCy96"
      },
      "source": [
        "#### Train and checkpoint the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NP9IySmCeCkn"
      },
      "source": [
        "The following training loop restores the latest checkpoint if one exists, using the `tf.train.Checkpoint` object that gathers the model, optimizer, and iterator created above. It then calls the training step on each batch of data and periodically writes checkpoints to disk."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BbCS5A6K1VSH"
      },
      "outputs": [],
      "source": [
        "def train_and_checkpoint(net, manager):\n",
        "  ckpt.restore(manager.latest_checkpoint)\n",
        "  if manager.latest_checkpoint:\n",
        "    print(\"Restored from {}\".format(manager.latest_checkpoint))\n",
        "  else:\n",
        "    print(\"Initializing from scratch.\")\n",
        "\n",
        "  for _ in range(50):\n",
        "    example = next(iterator)\n",
        "    loss = train_step(net, example, opt)\n",
        "    ckpt.step.assign_add(1)\n",
        "    if int(ckpt.step) % 10 == 0:\n",
        "      save_path = manager.save()\n",
        "      print(\"Saved checkpoint for step {}: {}\".format(int(ckpt.step), save_path))\n",
        "      print(\"loss {:1.2f}\".format(loss.numpy()))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ik3IBMTdPW41"
      },
      "outputs": [],
      "source": [
        "train_and_checkpoint(net, manager)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2wzcc1xYN-sH"
      },
      "source": [
        "#### Restore and continue training"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lw1QeyRBgsLE"
      },
      "source": [
        "After the first training cycle, you can pass a new model and manager and still pick up training exactly where you left off:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UjilkTOV2PBK"
      },
      "outputs": [],
      "source": [
        "opt = tf.keras.optimizers.Adam(0.1)\n",
        "net = Net()\n",
        "dataset = toy_dataset()\n",
        "iterator = iter(dataset)\n",
        "ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)\n",
        "manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)\n",
        "\n",
        "train_and_checkpoint(net, manager)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dxJT9vV-2PnZ"
      },
      "source": [
        "The `tf.train.CheckpointManager` object deletes old checkpoints. Above, it's configured to keep only the three most recent checkpoints."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3zmM0a-F5XqC"
      },
      "outputs": [],
      "source": [
        "print(manager.checkpoints)  # List the three remaining checkpoints"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qwlYDyjemY4P"
      },
      "source": [
        "These paths, e.g. `'./tf_ckpts/ckpt-10'`, are not files on disk. Instead, they are prefixes for an `index` file and one or more data files which contain the variable values. These prefixes are grouped together in a single `checkpoint` file (`'./tf_ckpts/checkpoint'`) where the `CheckpointManager` saves its state."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "t1feej9JntV_"
      },
      "outputs": [],
      "source": [
        "!ls ./tf_ckpts"
      ]
    },
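    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The prefix layout described above can be checked directly. The following is a minimal sketch, separate from the training run: it saves a throwaway checkpoint to a temporary directory and lists the files that share the latest prefix. The names `demo_dir`, `demo_ckpt`, `demo_manager`, `demo_prefix`, and `demo_files` are hypothetical and chosen so that they do not overwrite `ckpt` or `manager`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import glob\n",
        "import os\n",
        "import tempfile\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical scratch objects, kept separate from `ckpt` and `manager` above.\n",
        "demo_dir = tempfile.mkdtemp()\n",
        "demo_ckpt = tf.train.Checkpoint(v=tf.Variable(3.))\n",
        "demo_manager = tf.train.CheckpointManager(demo_ckpt, demo_dir, max_to_keep=2)\n",
        "for _ in range(3):\n",
        "  demo_manager.save()\n",
        "\n",
        "# Collect the names of all files that start with the latest prefix.\n",
        "demo_prefix = demo_manager.latest_checkpoint\n",
        "demo_files = sorted(\n",
        "    os.path.basename(p) for p in glob.glob(glob.escape(demo_prefix) + '*'))\n",
        "\n",
        "print(os.path.exists(demo_prefix))  # Whether the prefix itself is a file\n",
        "print(demo_files)  # Files sharing the prefix\n",
        "print(len(demo_manager.checkpoints))  # Bounded by max_to_keep"
      ]
    },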
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DR2wQc9x6b3X"
      },
      "source": [
        "<a id=\"loading_mechanics\"/>\n",
        "\n",
        "## Loading mechanics\n",
        "\n",
        "TensorFlow matches variables to checkpointed values by traversing a directed graph with named edges, starting from the object being loaded. Edge names typically come from attribute names in objects, for example the `\"l1\"` in `self.l1 = tf.keras.layers.Dense(5)`. `tf.train.Checkpoint` uses its keyword argument names, as in the `\"step\"` in `tf.train.Checkpoint(step=...)`.\n",
        "\n",
        "The dependency graph from the example above looks like this:\n",
        "\n",
        "![Visualization of the dependency graph for the example training loop](https://tensorflow.org/images/guide/whole_checkpoint.svg)\n",
        "\n",
        "The optimizer is in red, regular variables are in blue, and the optimizer slot variables are in orange. The other nodes—for example, representing the `tf.train.Checkpoint`—are in black.\n",
        "\n",
        "Slot variables are part of the optimizer's state, but are created for a specific variable. For example, the `'m'` edges above correspond to momentum, which the Adam optimizer tracks for each variable. Slot variables are only saved in a checkpoint if the variable and the optimizer would both be saved, thus the dashed edges."
      ]
    },
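    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see how attribute names and keyword argument names become edges, the sketch below writes a small object graph and lists the resulting variable keys. The `Inner` and `Outer` classes are hypothetical `tf.Module` subclasses used only for illustration; they are not part of the model above, and `edge_ckpt`, `edge_path`, and `edge_keys` are likewise made-up names."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "import tempfile\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "class Inner(tf.Module):  # Hypothetical example class\n",
        "  def __init__(self):\n",
        "    super().__init__()\n",
        "    self.w = tf.Variable([1., 2.])  # Edge named 'w'\n",
        "\n",
        "class Outer(tf.Module):  # Hypothetical example class\n",
        "  def __init__(self):\n",
        "    super().__init__()\n",
        "    self.inner = Inner()  # Edge named 'inner'\n",
        "\n",
        "# The keyword argument name 'root' becomes the first edge in the path.\n",
        "edge_ckpt = tf.train.Checkpoint(root=Outer())\n",
        "edge_path = edge_ckpt.write(os.path.join(tempfile.mkdtemp(), 'edges'))\n",
        "\n",
        "# Each key is the sequence of edge names leading to a variable.\n",
        "edge_keys = [name for name, _ in tf.train.list_variables(edge_path)]\n",
        "print(edge_keys)"
      ]
    },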
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VpY5IuanUEQ0"
      },
      "source": [
        "Calling `restore` on a `tf.train.Checkpoint` object queues the requested restorations, restoring variable values as soon as there's a matching path from the `Checkpoint` object. For example, you can load just the bias from the model you defined above by reconstructing one path to it through the network and the layer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wmX2AuyH7TVt"
      },
      "outputs": [],
      "source": [
        "to_restore = tf.Variable(tf.zeros([5]))\n",
        "print(to_restore.numpy())  # All zeros\n",
        "fake_layer = tf.train.Checkpoint(bias=to_restore)\n",
        "fake_net = tf.train.Checkpoint(l1=fake_layer)\n",
        "new_root = tf.train.Checkpoint(net=fake_net)\n",
        "status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))\n",
        "print(to_restore.numpy())  # This gets the restored value."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GqEW-_pJDAnE"
      },
      "source": [
        "The dependency graph for these new objects is a much smaller subgraph of the larger checkpoint you wrote above. It includes only the bias and a save counter that `tf.train.Checkpoint` uses to number checkpoints.\n",
        "\n",
        "![Visualization of a subgraph for the bias variable](https://tensorflow.org/images/guide/partial_checkpoint.svg)\n",
        "\n",
        "`restore` returns a status object, which has optional assertions. All of the objects created in the new `Checkpoint` have been restored, so `status.assert_existing_objects_matched` passes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P9TQXl81Dq5r"
      },
      "outputs": [],
      "source": [
        "status.assert_existing_objects_matched()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GoMwf8CFDu9r"
      },
      "source": [
        "There are many objects in the checkpoint which haven't matched, including the layer's kernel and the optimizer's variables. `status.assert_consumed` only passes if the checkpoint and the program match exactly, and would throw an exception here."
      ]
    },
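    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The difference between the two assertions can be reproduced on a small scale. This is a sketch with hypothetical names (`partial_saved`, `partial_a`, `partial_status`, and so on): it saves two variables, restores only one of them, and records whether `assert_consumed` raises an `AssertionError`. Calling `expect_partial` on the status object signals that the partial restore is intentional."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "import tempfile\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "# Save a checkpoint that contains two variables.\n",
        "partial_saved = tf.train.Checkpoint(a=tf.Variable(1.), b=tf.Variable(2.))\n",
        "partial_path = partial_saved.save(\n",
        "    os.path.join(tempfile.mkdtemp(), 'partial_demo'))\n",
        "\n",
        "# Restore into an object that only defines `a`.\n",
        "partial_a = tf.Variable(0.)\n",
        "partial_status = tf.train.Checkpoint(a=partial_a).restore(partial_path)\n",
        "\n",
        "# Every object created in the program found a match in the checkpoint.\n",
        "partial_status.assert_existing_objects_matched()\n",
        "\n",
        "# The checkpoint still holds `b`, which nothing in the program claimed.\n",
        "try:\n",
        "  partial_status.assert_consumed()\n",
        "  partial_consumed = True\n",
        "except AssertionError:\n",
        "  partial_consumed = False\n",
        "\n",
        "partial_status.expect_partial()  # Mark the partial restore as intentional\n",
        "print(partial_a.numpy(), partial_consumed)"
      ]
    },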
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KCcmJ-2j9RUP"
      },
      "source": [
        "### Deferred restorations\n",
        "\n",
        "`Layer` objects in TensorFlow may defer the creation of variables to their first call, when input shapes are available. For example, the shape of a `Dense` layer's kernel depends on both the layer's input and output shapes, and so the output shape required as a constructor argument is not enough information to create the variable on its own. Since calling a `Layer` also reads the variable's value, a restore must happen between the variable's creation and its first use.\n",
        "\n",
        "To support this idiom, `tf.train.Checkpoint` defers restores which don't yet have a matching variable."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TXYUCO3v-I72"
      },
      "outputs": [],
      "source": [
        "deferred_restore = tf.Variable(tf.zeros([1, 5]))\n",
        "print(deferred_restore.numpy())  # Not restored; still zeros\n",
        "fake_layer.kernel = deferred_restore\n",
        "print(deferred_restore.numpy())  # Restored"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-DWhJ3glyobN"
      },
      "source": [
        "### Manually inspecting checkpoints\n",
        "\n",
        "`tf.train.load_checkpoint` returns a `CheckpointReader` that gives lower-level access to the checkpoint contents. It contains mappings from each variable's key to the shape and dtype of that variable in the checkpoint. A variable's key is its object path, like in the graphs displayed above.\n",
        "\n",
        "Note: There is no higher-level structure to the checkpoint. It only knows the paths and values for the variables, and has no concept of `models`, `layers` or how they are connected."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RlRsADTezoBD"
      },
      "outputs": [],
      "source": [
        "reader = tf.train.load_checkpoint('./tf_ckpts/')\n",
        "shape_from_key = reader.get_variable_to_shape_map()\n",
        "dtype_from_key = reader.get_variable_to_dtype_map()\n",
        "\n",
        "sorted(shape_from_key.keys())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AVrdvbNvgq5V"
      },
      "source": [
        "So if you're interested in `net.l1.kernel`, you can look up its shape and dtype with the following code:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lYhX_XWCgl92"
      },
      "outputs": [],
      "source": [
        "key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'\n",
        "\n",
        "print(\"Shape:\", shape_from_key[key])\n",
        "print(\"Dtype:\", dtype_from_key[key].name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2Zk92jM5gRDW"
      },
      "source": [
        "It also provides a `get_tensor` method allowing you to inspect the value of a variable:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cDJO3cgmecvi"
      },
      "outputs": [],
      "source": [
        "reader.get_tensor(key)"
      ]
    },
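    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "When only a few values are needed, `tf.train.list_variables` and `tf.train.load_variable` offer a shorter route than creating a reader. The sketch below writes a throwaway checkpoint and reads one value back by key. The names `inspect_ckpt`, `inspect_path`, and `inspect_value` are hypothetical and do not refer to the training checkpoints above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "import tempfile\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical scratch checkpoint with a single 2x2 variable.\n",
        "inspect_ckpt = tf.train.Checkpoint(scale=tf.Variable([[1., 2.], [3., 4.]]))\n",
        "inspect_path = inspect_ckpt.write(os.path.join(tempfile.mkdtemp(), 'inspect'))\n",
        "\n",
        "# Each entry is a (key, shape) pair.\n",
        "for inspect_key, inspect_shape in tf.train.list_variables(inspect_path):\n",
        "  print(inspect_key, inspect_shape)\n",
        "\n",
        "# Load a single value as a NumPy array, addressed by its key.\n",
        "inspect_value = tf.train.load_variable(\n",
        "    inspect_path, 'scale/.ATTRIBUTES/VARIABLE_VALUE')\n",
        "print(inspect_value)"
      ]
    },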
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5fxk_BnZ4W1b"
      },
      "source": [
        "### Object tracking\n",
        "\n",
        "Checkpoints save and restore the values of `tf.Variable` objects by \"tracking\" any variable or trackable object set in one of their attributes. When executing a save, variables are gathered recursively from all of the reachable tracked objects.\n",
        "\n",
        "As with direct attribute assignments like `self.l1 = tf.keras.layers.Dense(5)`, assigning lists and dictionaries to attributes will track their contents."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rfaIbDtDHAr_"
      },
      "outputs": [],
      "source": [
        "save = tf.train.Checkpoint()\n",
        "save.listed = [tf.Variable(1.)]\n",
        "save.listed.append(tf.Variable(2.))\n",
        "save.mapped = {'one': save.listed[0]}\n",
        "save.mapped['two'] = save.listed[1]\n",
        "save_path = save.save('./tf_list_example')\n",
        "\n",
        "restore = tf.train.Checkpoint()\n",
        "v2 = tf.Variable(0.)\n",
        "assert 0. == v2.numpy()  # Not restored yet\n",
        "restore.mapped = {'two': v2}\n",
        "restore.restore(save_path)\n",
        "assert 2. == v2.numpy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UTKvbxHcI3T2"
      },
      "source": [
        "You may notice wrapper objects for lists and dictionaries. These wrappers are checkpointable versions of the underlying data structures. Just like the attribute-based loading, these wrappers restore a variable's value as soon as it's added to the container."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s0Uq1Hv5JCmm"
      },
      "outputs": [],
      "source": [
        "restore.listed = []\n",
        "print(restore.listed)  # ListWrapper([])\n",
        "v1 = tf.Variable(0.)\n",
        "restore.listed.append(v1)  # Restores v1, from restore() in the previous cell\n",
        "assert 1. == v1.numpy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OxCIf2J6JyQ8"
      },
      "source": [
        "Trackable objects include `tf.train.Checkpoint`, `tf.Module` and its subclasses (e.g. `keras.layers.Layer` and `keras.Model`), and recognized Python containers:\n",
        "\n",
        " * `dict` (and `collections.OrderedDict`)\n",
        " * `list`\n",
        " * `tuple` (and `collections.namedtuple`, `typing.NamedTuple`)\n",
        "\n",
        "Other container types are **not supported**, including:\n",
        "\n",
        " * `collections.defaultdict`\n",
        " * `set`\n",
        "\n",
        "All other Python objects are **ignored**, including:\n",
        "\n",
        " * `int`\n",
        " * `str`\n",
        " * `float`\n"
      ]
    },
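    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The sketch below illustrates two of these cases, using hypothetical attribute names (`pair` and `plain_number`): variables inside a `tuple` are tracked and show up in the checkpoint, while a plain Python `int` attribute is ignored and leaves no entry."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "import tempfile\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "container_ckpt = tf.train.Checkpoint()\n",
        "container_ckpt.pair = (tf.Variable(1.), tf.Variable(2.))  # Tracked container\n",
        "container_ckpt.plain_number = 7  # Ignored: not a trackable object\n",
        "\n",
        "container_path = container_ckpt.write(\n",
        "    os.path.join(tempfile.mkdtemp(), 'containers'))\n",
        "\n",
        "# Only the variables reachable through tracked attributes are written.\n",
        "container_keys = [name for name, _ in tf.train.list_variables(container_path)]\n",
        "print(container_keys)"
      ]
    },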
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "knyUFMrJg8y4"
      },
      "source": [
        "## Summary\n",
        "\n",
        "TensorFlow objects provide an easy automatic mechanism for saving and restoring the values of variables they use.\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "checkpoint.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
